import numpy as np
import time
import socket
import pickle
import threading
import random

# -------------------------------
# WAN-Scale HDGL Lattice Parameters
# -------------------------------
NUM_NODES = 16          # simulated nodes (all hosted in this one process)
NUM_CORES = 2           # cores per node
NUM_STRANDS = 8         # rows of each core's lattice/phases/weights arrays
SLOTS_PER_STRAND = 4    # columns of each core's lattice/phases/weights arrays
TOTAL_SLOTS = NUM_STRANDS * SLOTS_PER_STRAND  # carriers synthesized per core

PHI = 1.6180339887      # golden ratio
# Per-strand lattice drift rates: 1 / PHI**(7*k) for k = 1..NUM_STRANDS,
# i.e. a steeply decreasing sequence (strand 0 drifts fastest).
OMEGA_BASE = 1 / (PHI**np.arange(1, NUM_STRANDS+1))**7

HOP_INTERVAL = 0.1      # seconds between frequency re-draws in the main loop
# Candidate (low, high) bands; each carrier picks a band, then a value inside it.
# Presumably Hz (consistent with FS) — confirm.
FREQ_HOP_RANGES = [(100e3, 200e3), (200e3, 300e3), (300e3, 400e3)]
FS = 1_000_000          # sample rate in samples per second
BLOCK_SIZE = 4096       # samples per generated RF block
t = np.arange(BLOCK_SIZE) / FS  # sample instants of one block, in seconds

PHASE_CORRECTION_GAIN = 0.1  # fraction of the peer phase error applied per adaptation pass
FEEDBACK_GAIN = 0.05         # NOTE(review): not referenced by any code visible in this file

# -------------------------------
# Node Lattice Initialization
# -------------------------------
def init_node():
    """Create the per-core state for one simulated node.

    Returns a list with one dict per core (NUM_CORES entries). Each dict holds
    NUM_STRANDS x SLOTS_PER_STRAND arrays -- ``lattice`` drawn from
    U(0.5, 1.0), ``phases`` drawn from U(0, 2*pi), ``weights`` all ones -- plus
    ``omega``, the per-strand drift rates OMEGA_BASE scaled by 1/(core index + 1).
    """
    grid_shape = (NUM_STRANDS, SLOTS_PER_STRAND)

    def _make_core(core_index):
        # Draw order (amplitudes, then phases) matches the global RNG sequence.
        amplitudes = np.random.uniform(0.5, 1.0, grid_shape)
        phase_offsets = np.random.uniform(0, 2*np.pi, grid_shape)
        return {
            'lattice': amplitudes,
            'phases': phase_offsets,
            'weights': np.ones(grid_shape),
            'omega': OMEGA_BASE*(1/(core_index+1)),
        }

    return [_make_core(core_index) for core_index in range(NUM_CORES)]

node_lattices = [init_node() for _ in range(NUM_NODES)]  # node index -> list of per-core state dicts

# -------------------------------
# Networking (UDP Broadcast for WAN Simulation)
# -------------------------------
UDP_PORT = 5005
# NOTE(review): 255.255.255.255 is the limited-broadcast address, which routers
# do not forward -- this reaches only the local segment, not a real WAN.
BROADCAST_IP = "255.255.255.255"

# A single UDP socket is shared by the sender (broadcast_node) and the
# background receiver (listen_for_peers).
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Permission to send to a broadcast address.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
# Bound on all interfaces to the same port we broadcast to, so this process
# typically receives its own broadcasts too (OS-dependent -- confirm).
# NOTE(review): no SO_REUSEADDR/SO_REUSEPORT, so a second instance on the same
# host will likely fail to bind.
sock.bind(("", UDP_PORT))
# Very short timeout so the listener's recvfrom returns frequently.
sock.settimeout(0.001)

peers = {}  # node_id -> {'time': sender timestamp, 'cores': per-core snapshot}; shared across threads

def broadcast_node(node_id, cores, gps_time):
    """Send one node's full core state as a single UDP broadcast datagram.

    The datagram is a pickled dict with keys 'node_id', 'cores' and
    'timestamp' (carrying ``gps_time``), addressed to BROADCAST_IP:UDP_PORT.
    NOTE(review): receivers unpickle this payload, which is only safe between
    mutually trusted hosts.
    """
    message = {'node_id': node_id, 'cores': cores, 'timestamp': gps_time}
    datagram = pickle.dumps(message)
    destination = (BROADCAST_IP, UDP_PORT)
    sock.sendto(datagram, destination)

def _decode_peer_packet(data):
    """Deserialize one received datagram into ``(node_id, peer_entry)``.

    Args:
        data: raw bytes as returned by ``sock.recvfrom``.

    Returns:
        ``(node_id, {'time': timestamp, 'cores': cores})`` -- the entry shape
        stored in ``peers``.

    Raises:
        ValueError: the payload is not a dict carrying ``node_id`` (integral),
            ``timestamp`` (real number) and ``cores``.
        Exception: whatever ``pickle.loads`` raises on a corrupt payload.
    """
    import numbers

    # SECURITY: pickle.loads executes arbitrary code embedded in the payload,
    # and this payload comes off the network -- any host that can reach
    # UDP_PORT can run code in this process. Only acceptable on a trusted
    # network; a data-only encoding would remove the risk.
    packet = pickle.loads(data)
    if not isinstance(packet, dict):
        raise ValueError(f"unexpected packet type {type(packet).__name__}")
    missing = [key for key in ('node_id', 'timestamp', 'cores') if key not in packet]
    if missing:
        raise ValueError(f"packet missing fields {missing}")
    node_id = packet['node_id']
    timestamp = packet['timestamp']
    # adapt_to_peers computes node_id % NUM_NODES (a str would be treated as a
    # %-format string) and prune_peers subtracts timestamps; reject types that
    # would crash those background threads later.
    if not isinstance(node_id, numbers.Integral):
        raise ValueError(f"non-integer node_id {node_id!r}")
    if not isinstance(timestamp, numbers.Real):
        raise ValueError(f"non-numeric timestamp {timestamp!r}")
    return node_id, {'time': timestamp, 'cores': packet['cores']}


def listen_for_peers():
    """Receive peer snapshots forever, keeping the latest one per node_id in ``peers``.

    Intended to run on a background thread. Receive timeouts (the shared
    socket uses a very short timeout) are simply retried; other receive
    errors and undecodable or malformed packets are reported and dropped, so
    a single bad datagram never stops the listener.
    """
    while True:
        try:
            data, _addr = sock.recvfrom(16384)
        except socket.timeout:
            continue
        except Exception as e:
            print(f"Network error: {e}")
            continue
        try:
            node_id, entry = _decode_peer_packet(data)
        except Exception as e:  # unpickling can raise almost any exception type
            print(f"Network error: {e}")
            continue
        peers[node_id] = entry

# Receive peer snapshots in the background; daemon, so it never blocks interpreter exit.
listener_thread = threading.Thread(daemon=True, target=listen_for_peers)
listener_thread.start()

# -------------------------------
# Adaptive Phase & Weight Correction Across WAN
# -------------------------------
def _wrap_phase(delta):
    """Map phase differences (radians, scalar or array) onto [-pi, pi).

    The result is the shortest signed rotation equivalent to ``delta``.
    """
    return (delta + np.pi) % (2*np.pi) - np.pi


def adapt_to_peers():
    """Forever nudge local core state toward the latest peer snapshots.

    Every 0.05 s, each entry in ``peers`` is applied to local node
    ``node_id % NUM_NODES``. Cores are paired by position, and cores present
    on only one side are ignored. For each pair, the local phases move
    PHASE_CORRECTION_GAIN of the way along the shortest rotation toward the
    peer's phases. The local weights are replaced by a 90/10 blend of local
    and peer weights. A malformed peer snapshot is reported and skipped.
    """
    while True:
        # Iterate a snapshot: the listener and prune threads insert/delete keys
        # concurrently, and iterating the live dict can raise
        # "RuntimeError: dictionary changed size during iteration", which
        # would silently kill this thread.
        for node_id, peer in list(peers.items()):
            try:
                local_cores = node_lattices[node_id % NUM_NODES]
                for n_core, peer_core in zip(local_cores, peer['cores']):
                    # A plain "% (2*pi)" lands in [0, 2*pi) and always rotates
                    # forward: a peer 0.1 rad behind would pull us ~6.18 rad
                    # ahead. Wrapping to [-pi, pi) corrects the short way round.
                    phase_diff = _wrap_phase(peer_core['phases'] - n_core['phases'])
                    n_core['phases'] += PHASE_CORRECTION_GAIN * phase_diff
                    n_core['weights'] = 0.9*n_core['weights'] + 0.1*peer_core['weights']
            except (KeyError, TypeError, ValueError, IndexError) as e:
                # Missing fields, wrong types, or mismatched array shapes.
                print(f"Skipping peer {node_id!r}: {e}")
        time.sleep(0.05)

# Background sync pulling local cores toward peer snapshots; daemon, so it never blocks exit.
adapt_thread = threading.Thread(daemon=True, target=adapt_to_peers)
adapt_thread.start()

# -------------------------------
# RF Signal Generation Per Node
# -------------------------------
def get_hopped_frequencies():
    """Draw one random carrier frequency for every slot of every core of every node.

    For each carrier, a band is picked uniformly from FREQ_HOP_RANGES, and then
    a value is drawn uniformly inside that band. Returns a 1-D float array of
    length TOTAL_SLOTS * NUM_CORES * NUM_NODES.
    """
    n_carriers = TOTAL_SLOTS * NUM_CORES * NUM_NODES
    drawn = []
    for _ in range(n_carriers):
        band = random.choice(FREQ_HOP_RANGES)
        drawn.append(random.uniform(*band))
    return np.array(drawn)

def measure_signal(rf_block):
    """Score a block as |sum of samples| divided by the sample count.

    Simulated measurement (could be ADC or SDR capture). Larger values mean
    the samples add up more coherently.
    """
    coherent_sum = np.sum(rf_block)
    sample_count = len(rf_block)
    return np.abs(coherent_sum) / sample_count

def feedback_optimize(rf_block):
    """Hill-climb every slot's phase and amplitude against ``rf_block``.

    For each slot of each core in ``node_lattices``, draw a small random
    perturbation (+/-0.05 rad of phase, +/-0.01 of amplitude). The perturbed
    slot is added to every sample of ``rf_block`` as a constant complex
    offset. The perturbation is kept (written in place into the core's
    ``phases``/``lattice``) when that raises ``measure_signal`` above the
    unperturbed block's score. ``rf_block`` itself is never modified.
    """
    # rf_block is not modified below, so its score is loop-invariant; the
    # original recomputed it once per probed slot.
    baseline = measure_signal(rf_block)
    for node in node_lattices:
        for core in node:
            lattice = core['lattice']
            phases = core['phases']
            # Row-major walk over the actual array shape (same order as the
            # former nested strand/slot loops).
            for pos in np.ndindex(lattice.shape):
                probe_phase = phases[pos] + np.random.uniform(-0.05, 0.05)
                probe_amp = lattice[pos] + np.random.uniform(-0.01, 0.01)
                probe_block = rf_block + probe_amp*np.exp(1j*probe_phase)
                if measure_signal(probe_block) > baseline:
                    phases[pos] = probe_phase
                    lattice[pos] = probe_amp

def generate_rf_block(freqs):
    """Advance every core's state one step and synthesize one normalized RF block.

    Each core of each node in ``node_lattices`` is mutated in place:
      * ``lattice`` drifts by ``0.02 * omega`` per strand,
      * ``weights`` become a 90/10 blend of old weights and the new lattice,
      * ``phases`` advance by ``0.05 * lattice``.
    Every slot then contributes a complex carrier. Its frequency is the slot's
    hopped frequency plus ``50e3 * (lattice - 0.5)``, its phase is the slot's
    phase, and its amplitude is ``lattice / max(lattice)`` within that core.

    Args:
        freqs: 1-D sequence of carrier frequencies, consumed consecutively,
            one slot after another, core by core, node by node. This is the
            layout ``get_hopped_frequencies`` produces
            (TOTAL_SLOTS * NUM_CORES * NUM_NODES entries). Shorter inputs are
            reused cyclically, so passing a single core's worth (TOTAL_SLOTS)
            reproduces the previous behaviour where every core shared the
            same carriers.

    Returns:
        complex64 array of BLOCK_SIZE samples, scaled so the peak magnitude is
        1. An all-zero block is returned unscaled instead of dividing by zero.
    """
    rf_block = np.zeros(BLOCK_SIZE, dtype=np.complex64)
    offset = 0  # index of this core's first entry in freqs
    for node in node_lattices:
        for core in node:
            lattice = core['lattice']
            phases = core['phases']
            weights = core['weights']
            omega = core['omega']
            lattice += 0.02*omega[:, None]
            weights[:] = 0.9*weights + 0.1*lattice
            phases += 0.05*lattice

            # Flatten row-major: flat index = strand * SLOTS_PER_STRAND + slot.
            levels = lattice.ravel()
            n_slots = levels.size
            # Previously every core read freqs[0:TOTAL_SLOTS], leaving the rest
            # of the per-node/per-core hop table unused. Give each core its own
            # slice; mode='wrap' keeps shorter freq arrays working.
            core_freqs = np.take(freqs, np.arange(offset, offset + n_slots), mode='wrap')
            offset += n_slots

            amps = levels / levels.max()  # max hoisted out of the slot loop
            freq_offsets = 50e3 * (levels - 0.5)
            # (n_slots, BLOCK_SIZE) carrier matrix, then amplitude-weighted sum over slots.
            carriers = np.exp(1j*(np.outer(2*np.pi*(core_freqs + freq_offsets), t)
                                  + phases.ravel()[:, None]))
            rf_block += amps @ carriers
    peak = np.max(np.abs(rf_block))
    if peak == 0:
        return rf_block
    return rf_block / peak

# -------------------------------
# WAN Node Cleanup
# -------------------------------
def _stale_peer_ids(entries, now, timeout):
    """Return ids from ``(node_id, entry)`` pairs whose ``entry['time']`` is more than ``timeout`` s before ``now``."""
    return [nid for nid, entry in entries if now - entry['time'] > timeout]


def prune_peers(timeout=5):
    """Once per second, forget peers whose last snapshot is older than ``timeout`` seconds.

    NOTE(review): ``entry['time']`` is the timestamp the *sender* put in its
    packet. It is compared against the local ``time.time()``, so clock skew
    between hosts shifts the effective timeout.
    """
    while True:
        now = time.time()
        # Snapshot first: the listener thread inserts keys concurrently, and
        # iterating the live dict can raise RuntimeError and kill this thread.
        for nid in _stale_peer_ids(list(peers.items()), now, timeout):
            peers.pop(nid, None)  # defensive: never raise if the entry is already gone
        time.sleep(1)

# Periodically drop peers that stopped broadcasting; daemon, so it never blocks exit.
prune_thread = threading.Thread(daemon=True, target=prune_peers)
prune_thread.start()

# -------------------------------
# Main Loop
# -------------------------------
try:
    print("WAN-Scale Geo-Synchronized HDGL RF Streaming Started.")
    last_hop = time.time()
    freqs = get_hopped_frequencies()

    # Each iteration: optionally re-hop frequencies, synthesize one block,
    # hill-climb the lattices against it, then broadcast every node's state.
    while True:
        # NOTE(review): despite the name, this is the host wall clock
        # (time.time()), not a GPS-disciplined time source.
        gps_time = time.time()
        # Draw a fresh frequency table once HOP_INTERVAL seconds have elapsed.
        if time.time() - last_hop > HOP_INTERVAL:
            freqs = get_hopped_frequencies()
            last_hop = time.time()
        rf_block = generate_rf_block(freqs)
        feedback_optimize(rf_block)
        # All simulated nodes are broadcast from this one process with the same timestamp.
        for node_id, node in enumerate(node_lattices):
            broadcast_node(node_id, node, gps_time)
        # SDR hardware output: sdr.write_samples(rf_block)
        # Sleep one block's duration. NOTE(review): this is added on top of the
        # compute time above, so blocks are produced slower than FS samples/s.
        time.sleep(BLOCK_SIZE/FS)

except KeyboardInterrupt:
    print("WAN-Scale HDGL Streaming stopped.")
